{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 9. Recursive Neural Networks and Constituency Parsing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "I recommend you take a look at these materials first."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture14-TreeRNNs.pdf\n",
    "* https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import nltk\n",
    "import random\n",
    "import numpy as np\n",
    "from collections import Counter, OrderedDict\n",
    "import nltk\n",
    "from copy import deepcopy\n",
    "import os\n",
    "from IPython.display import Image, display\n",
    "from nltk.draw import TreeWidget\n",
    "from nltk.draw.util import CanvasFrame\n",
    "from nltk.tree import Tree as nltkTree\n",
    "def flatten(l):\n",
    "    \"\"\"Flatten one level of nesting: concatenate the items of every sublist of ``l``.\"\"\"\n",
    "    return [item for sublist in l for item in sublist]\n",
    "\n",
    "\n",
    "# Seed both RNGs the notebook draws from: ``random`` (shuffling and sampling)\n",
    "# and torch (parameter initialisation), so weight initialisation is repeatable.\n",
    "random.seed(1024)\n",
    "torch.manual_seed(1024)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "USE_CUDA = torch.cuda.is_available()\n",
    "gpus = [0]\n",
    "# Select the GPU only when CUDA is actually available, so the notebook\n",
    "# still runs on a CPU-only machine.\n",
    "if USE_CUDA:\n",
    "    torch.cuda.set_device(gpus[0])\n",
    "\n",
    "# Tensor constructors matching the device in use (CUDA variants when available).\n",
    "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n",
    "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n",
    "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def getBatch(batch_size, train_data):\n",
    "    \"\"\"Shuffle ``train_data`` in place and yield it in consecutive mini-batches.\n",
    "\n",
    "    Args:\n",
    "        batch_size: number of examples per batch; must be positive.\n",
    "        train_data: list of examples. It is shuffled in place with the\n",
    "            ``random`` module, so the caller's list order changes.\n",
    "\n",
    "    Yields:\n",
    "        Slices of ``train_data`` holding ``batch_size`` items each; the last\n",
    "        slice holds whatever remains and may be shorter. Nothing is yielded\n",
    "        when ``train_data`` is empty.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``batch_size`` is not positive (raised when the\n",
    "            generator is first advanced).\n",
    "    \"\"\"\n",
    "    if batch_size <= 0:\n",
    "        raise ValueError('batch_size must be positive, got %r' % (batch_size,))\n",
    "    random.shuffle(train_data)\n",
    "    # Stepping by batch_size covers every item exactly once and never\n",
    "    # produces an empty batch.\n",
    "    for sindex in range(0, len(train_data), batch_size):\n",
    "        yield train_data[sindex: sindex + batch_size]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Borrowed from https://stackoverflow.com/questions/31779707/how-do-you-make-nltk-draw-trees-that-are-inline-in-ipython-jupyter\n",
    "\n",
    "def draw_nltk_tree(tree):\n",
    "    \"\"\"Render an NLTK tree inline in the notebook as a PNG image.\n",
    "\n",
    "    The tree is drawn on a Tk canvas, exported to PostScript, converted to\n",
    "    PNG by the external ``convert`` command (presumably ImageMagick) and shown\n",
    "    with IPython's ``display``. Both temporary files are written to the\n",
    "    current working directory and removed afterwards, even when a step fails.\n",
    "\n",
    "    Args:\n",
    "        tree: tree object handed to ``nltk.draw.TreeWidget``.\n",
    "    \"\"\"\n",
    "    ps_path = 'tmp_tree_output.ps'\n",
    "    png_path = 'tmp_tree_output.png'\n",
    "    try:\n",
    "        cf = CanvasFrame()\n",
    "        try:\n",
    "            tc = TreeWidget(cf.canvas(), tree)\n",
    "            tc['node_font'] = 'arial 15 bold'\n",
    "            tc['leaf_font'] = 'arial 15'\n",
    "            tc['node_color'] = '#005990'\n",
    "            tc['leaf_color'] = '#3F8F57'\n",
    "            tc['line_color'] = '#175252'\n",
    "            cf.add_widget(tc, 50, 50)\n",
    "            cf.print_to_file(ps_path)\n",
    "        finally:\n",
    "            # Always release the Tk frame, even if drawing or export fails.\n",
    "            cf.destroy()\n",
    "        os.system('convert tmp_tree_output.ps tmp_tree_output.png')\n",
    "        display(Image(filename=png_path))\n",
    "    finally:\n",
    "        # Portable cleanup (no shell ``rm``); skip files that were never created.\n",
    "        for path in (ps_path, png_path):\n",
    "            if os.path.exists(path):\n",
    "                os.remove(path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data load and Preprocessing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### [Stanford Sentiment Treebank](https://nlp.stanford.edu/sentiment/index.html)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(3 (2 (1 Deflated) (2 (2 ending) (2 aside))) (4 (2 ,) (4 (2 there) (3 (3 (2 's) (3 (2 much) (2 (2 to) (3 (3 recommend) (2 (2 the) (2 film)))))) (2 .)))))\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Pick one random line of the training file to inspect; the context manager\n",
    "# closes the file handle as soon as the lines have been read.\n",
    "with open('../dataset/trees/train.txt', 'r', encoding='utf-8') as fid:\n",
    "    sample = random.choice(fid.readlines())\n",
    "print(sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAhEAAAGVCAMAAAB3iVNzAAAJJGlDQ1BpY2MAAHjalZVnUJNZF8fv\n8zzphUASQodQQ5EqJYCUEFoo0quoQOidUEVsiLgCK4qINEUQUUDBVSmyVkSxsCgoYkE3yCKgrBtX\nERWUF/Sd0Xnf2Q/7n7n3/OY/Z+4995wPFwCCOFgSvLQnJqULvJ3smIFBwUzwg8L4aSkcT0838I96\nPwyg5XhvBfj3IkREpvGX4sLSyuWnCNIBgLKXWDMrPWWZDy8xPTz+K59dZsFSgUt8Y5mjv/Ho15xv\nLPqa4+vNXXoVCgAcKfoHDv+B/3vvslQ4gvTYqMhspk9yVHpWmCCSmbbcCR6Xy/QUJEfFJkT+UPC/\nSv4HpUdmpy9HbnLKBkFsdEw68/8ONTIwNATfZ/HW62uPIUb//85nWd+95HoA2LMAIHu+e+GVAHTu\nAED68XdPbamvlHwAOu7wMwSZ3zzU8oYGBEABdCADFIEq0AS6wAiYAUtgCxyAC/AAviAIrAN8EAMS\ngQBkgVywDRSAIrAH7AdVoBY0gCbQCk6DTnAeXAHXwW1wFwyDJ0AIJsArIALvwTwEQViIDNEgGUgJ\nUod0ICOIDVlDDpAb5A0FQaFQNJQEZUC50HaoCCqFqqA6qAn6BToHXYFuQoPQI2gMmob+hj7BCEyC\n6bACrAHrw2yYA7vCvvBaOBpOhXPgfHg3XAHXwyfgDvgKfBsehoXwK3gWAQgRYSDKiC7CRriIBxKM\nRCECZDNSiJQj9Ugr0o30IfcQITKDfERhUDQUE6WLskQ5o/xQfFQqajOqGFWFOo7qQPWi7qHGUCLU\nFzQZLY/WQVugeehAdDQ6C12ALkc3otvR19DD6An0ewwGw8CwMGYYZ0wQJg6zEVOMOYhpw1zGDGLG\nMbNYLFYGq4O1wnpgw7Dp2AJsJfYE9hJ2CDuB/YAj4pRwRjhHXDAuCZeHK8c14y7ihnCTuHm8OF4d\nb4H3wEfgN+BL8A34bvwd/AR+niBBYBGsCL6EOMI2QgWhlXCNMEp4SyQSVYjmRC9iLHErsYJ4iniD\nOEb8SKKStElcUggpg7SbdIx0mfSI9JZMJmuQbcnB5HTybnIT+Sr5GfmDGE1MT4wnFiG2RaxarENs\nSOw1BU9Rp3Ao6yg5lHLKGcodyow4XlxDnCseJr5ZvFr8nPiI+KwETcJQwkMiUaJYolnipsQUFUvV\noDpQI6j51CPUq9RxGkJTpXFpfNp2WgPtGm2CjqGz6Dx6HL2IfpI+QBdJUiWNJf0lsyWrJS9IChkI\nQ4PBYyQwShinGQ8Yn6QUpDhSkVK7pFqlhqTmpOWkbaUjpQul26SHpT/JMGUcZOJl9sp0yjyVRclq\ny3rJZskekr0mOyNHl7OU48sVyp2WeywPy2vLe8tvlD8i3y8/q6Co4KSQolCpcFVhRpGhaKsYp1im\neFFxWommZK0Uq1SmdEnpJVOSyWEmMCuYvUyRsryys3KGcp3ygPK8CkvFTyVPpU3lqSpBla0apVqm\n2qMqUlNSc1fLVWtRe6yOV2erx6gfUO9Tn9NgaQRo7NTo1JhiSbN4rBxWC2tUk6xpo5mqWa95Xwuj\nxdaK1zqodVcb1jbRjtGu1r6jA+uY6sTqHNQZXIFeYb4iaUX9ihFdki5HN1O3RXdMj6Hnppen16n3\nWl9NP1h/r36f/hcDE4MEgwaDJ4ZUQxfDPMNuw7+NtI34RtVG91eSVzqu3LKya+UbYx3jSONDxg9N\naCbuJjtNekw+m5qZCkxbTafN1MxCzWrMRth0tie7mH3DHG1uZ77F/Lz5RwtTi3SL0xZ/Wepaxls2\nW06tYq2KXNWwatxKxSrMqs5KaM20DrU+bC20UbYJs6m3eW6rahth22g7ydHixHFOcF7bGdgJ7Nrt\n5rgW3E3cy/aIvZN9of2AA9XBz6HK4Z
mjimO0Y4ujyMnEaaPTZWe0s6vzXucRngKPz2viiVzMXDa5\n9LqSXH1cq1yfu2m7Cdy63WF3F/d97qOr1Vcnre70AB48j30eTz1Znqmev3phvDy9qr1eeBt653r3\n+dB81vs0+7z3tfMt8X3ip+mX4dfjT/EP8W/ynwuwDygNEAbqB24KvB0kGxQb1BWMDfYPbgyeXeOw\nZv+aiRCTkIKQB2tZa7PX3lwnuy5h3YX1lPVh68+EokMDQptDF8I8wurDZsN54TXhIj6Xf4D/KsI2\noixiOtIqsjRyMsoqqjRqKtoqel/0dIxNTHnMTCw3tir2TZxzXG3cXLxH/LH4xYSAhLZEXGJo4rkk\nalJ8Um+yYnJ28mCKTkpBijDVInV/qkjgKmhMg9LWpnWl05c+xf4MzYwdGWOZ1pnVmR+y/LPOZEtk\nJ2X3b9DesGvDZI5jztGNqI38jT25yrnbcsc2cTbVbYY2h2/u2aK6JX/LxFanrce3EbbFb/stzyCv\nNO/d9oDt3fkK+Vvzx3c47WgpECsQFIzstNxZ+xPqp9ifBnat3FW560thROGtIoOi8qKFYn7xrZ8N\nf674eXF31O6BEtOSQ3swe5L2PNhrs/d4qURpTun4Pvd9HWXMssKyd/vX779Zblxee4BwIOOAsMKt\noqtSrXJP5UJVTNVwtV11W418za6auYMRB4cO2R5qrVWoLar9dDj28MM6p7qOeo368iOYI5lHXjT4\nN/QdZR9tapRtLGr8fCzpmPC49/HeJrOmpmb55pIWuCWjZfpEyIm7J+1PdrXqtta1MdqKToFTGade\n/hL6y4PTrqd7zrDPtJ5VP1vTTmsv7IA6NnSIOmM6hV1BXYPnXM71dFt2t/+q9+ux88rnqy9IXii5\nSLiYf3HxUs6l2cspl2euRF8Z71nf8+Rq4NX7vV69A9dcr9247nj9ah+n79INqxvnb1rcPHeLfavz\ntuntjn6T/vbfTH5rHzAd6Lhjdqfrrvnd7sFVgxeHbIau3LO/d/0+7/7t4dXDgw/8HjwcCRkRPox4\nOPUo4dGbx5mP559sHUWPFj4Vf1r+TP5Z/e9av7cJTYUXxuzH+p/7PH8yzh9/9UfaHwsT+S/IL8on\nlSabpoymzk87Tt99ueblxKuUV/MzBX9K/FnzWvP12b9s/+oXBYom3gjeLP5d/Fbm7bF3xu96Zj1n\nn71PfD8/V/hB5sPxj+yPfZ8CPk3OZy1gFyo+a33u/uL6ZXQxcXHxPy6ikLxyKdSVAAAAIGNIUk0A\nAHomAACAhAAA+gAAAIDoAAB1MAAA6mAAADqYAAAXcJy6UTwAAACrUExURf///wBZkABZkABZkABZ\nkABZkABZkABZkABZkABZkABZkABZkBdSUhdSUhdTUxdSUhdSUhdSUhdSUhdSUhdSUhdSUhdSUhdS\nUhdSUhdSUhdSUhdSUgBZkABZkBdSUgBZkBdSUj+PVz+PVz+PVz+PVz+PVz+PVz+PVz+PVz+PVz+P\nVz+PVz+PVz+PVz+PVxdSUhdSUhdSUhdTUxdSUhhUVABZkBdSUj+PV////2WyAXMAAAA1dFJOUwAR\niO6ZuyJVZjPdqkR3dZnMu6OIZjMRqt3uVSJEzMd31rsziN2ZIkQRqlXuZsx3r+HSW+wg7fJpnQAA\nAAFiS0dEAIgFHUgAAAAJcEhZcwAAAEgAAABIAEbJaz4AAAAHdElNRQfhCwINNSfD9n+TAAASs0lE\nQVR42u2dCZuiuBqFU9pVXVWtrQ6uZffMZVFwq5m7SP7/P7skiFuxhiAJnPd5ulAgIZpD8sVODoQA\nAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\nAAAAAAAAAAAAAAAAAAAAaDZPne6x++2p7mIAZfh27D53j891FwMoQ6fzQp6Ox7qLAZTiO9oIcEUn\niC
Ne6i4EUIjX5273te5CAKV4O3brLgJQhvfjE3lBZAnOPPPR54+6iwGU4ekbfqECAAAAgCi9/s9+\nr+5CADXo9QdD3x/9MfL94QCyaDcnMRjjCX87GRuQRWu5E8MFyKJ1THvGLF4MF06ymBm9ad3FBVXC\nxDD3FyPjY5Lj7MmHMVr4c8iimVzEsCyUbglZNA9RMVyALBpDWJVlxFBNXqAGqrmvy7c3oAaqbuQh\nC4143Pig2LgF1EAdvyHk+W0D1EDNvzMm/v4JakCZH50hCxUYqyGGCydZjOsuR2vpKSQG1UsFAABA\nUZ5+/aq7CK1EXReQH1geVgcdZV1Afh2hiDp47bwp6gLy3lWyWK3g5fhedxG+0jl+hyJq4un3UT3P\nh5fuDwJF1MPbu4KCID+6L1BEPbx2uwoO8t6Ovzud47GjoFabztvxXUXjKDbQYKg4CGo478f354C3\nussRB3qNOjjdiwp2HFAEAAAAoAG9n0pOTZmOf46x9uvxLI354s/F3FBrycT0wxj68z/n/tD4gCoe\nyHQ88kfBjRht1aBnjPzFbMw0uhzPFv7IULINayC9wVXbwNqKQf3f/KQ/8/1Zf5K+C1RAjAJuFFJH\nkZIbhKtmA1RBUi9RX+8xHQ8ygoYwtBgo07c1idS2oIbeI39dZ+sGFCZHjT+09yjcHyDYlEneXuFB\nvYdwzIhgUw6F7v2qe4/StzqCzZII1HBlvYescADBpjCivUAFvYfsWkSwKUCpe11q71FRS49gswgS\nalRO71FxNIhgMxeyWv2y+TzoJkawmYUhMTJkbY0hWIwHdvRhmCJY0ObTk3tTimbXe/BgYDpGRAEA\nAIVpryFHez95Kr/e5S53EHQeqcOwBFYkcfx1fJb6vQg6j9RhWAIrklj+epW7JErQeaQOwxJYkSQh\n/XsRdB55sGEJrEgSkf29CDqPPNiwBFYkyUj+XgSdRx5tWAIrkmTkfi+CziOPNiyBFUkKUhUh6Dzy\ncMMSWJEk8YvfKR1p96eg80g9hiXoNWLohLdKR1Z+gs4j9RiWQBEAAACAaix/6jFfrdf/2cdzqGPx\npebWX/xLbF6d3GKkMflgDw/1R3+MfH+OR4h+xZd4Ty9HizH5WAwFJqzJLEYS015/wGQwix4wvPzg\njxAdDvDY+it8efMNjcWMfbFTwzcKf8ESixEHe34xaxiMmCe99canx9Zj6iVHWlVMhouP08ve5eXD\ni3HPqWFgzy9epp5msIdFDmd4QLmsqjD82VXDYCxGxb7YKhTBnlg7L3TzXzUl7e1G5FTFZDi/zScI\nKQpFmHIVcbnje8Xv+KBVMdocdY4kVEVs5FAswpRRDM5EUlTQ3qhTQlX05vO4TApFmBKKETUMMuuw\njVFn6apIqfgCEWa5Ylwqrop2vmVR56jk+sfefJhSC7kjTNFiLHuPatxbE3WWU8R05qenzxthChQj\nbBjmFTUMCZ+3BVFnKUUE0WPmF5MvwixUjFPDMBrUdLs2O+osoYjpLNf9nyvCzFsMhe7RpkadA2FF\njHP/CpUjwswuxvL0P1R1NQzxNDDqNEQVMVj0C1xlMShVjPHpf6hqbxjiiaLOcd0FqZVivwYuS9pc\n9dVvmac6FBIAAEQxqUlMK+0EK3deQDdsGuCs1le7XGp5hLqXHeZ9ouuDdzx9K+QBUvB0qakfxose\nxTxh0+CPtaHbyy6X7bqqdIveJ0pRxO9iHiAFT5ea+mGwlWkaFPOEHVb3ZucRsnXobuW5QaNhskr3\nVkHrsSXsvX06xs9xtimK6HSeyEv+RVEFT5ea+lG8sGI+2hdFnJMiLLon7s4l68Pq3EasHI/sqcff\nR8e8QBbeJkUR/Cso5gFSzjLkwYYjovylWxtBgmZgswq2W3pWhLcOt+x9dGxPg53rdEU8/T5+L1CC\ngqdLTf0ono/vesURhCvCpBw3UsR6w99yRUTHwtNTFfGr2y3ivlDw
dKmpH0YnkIQulojnXmNLTDvc\nFSnC3KzPbUR0LFsRr8ffRRb3FzxdaupH8qZH50YukaUT/DuEuyJFsHpfR73G6RjvNawURQR1VKR5\nLHi61NSP4lf3XYcAOOI0+txZTAmfwfhic1aEY5O1ySJOHl2Gx9a7FRuDJCoi+OC/mQdIzqvnOd2W\ndbG6eOry0ee3usuRE/4L1S78hSoYWdKgp4gU4Tr0YNlBgHEIR5/sWDj6TP7NMnLtyXn1PKdTW9LF\nauPlx1GjX6jUx/TqLgFQCm9TdwkAAKBaemOjaVMuC1FwNp7w5D0Jqatm0jPYHFx/9PffI59NxTV6\nis7+qxQogs2jGxujIZ8a3o8mGS7ZbPG57w/ZHG2FpghXT7sVMfkwmCuFPxokzAbOPKF5tFQRxZqA\nuEaksbROEUHlzgTDhHOgMWty6NkeRUirzxKa0oEWKKKaNr+xoWejFfGAuLB5oWczFfHoG7hJoWfT\nFFFrJ9+I0LMxilCoNvQOPfVXhKottq6hp86K0CKq06KQ12ipCP1uP1Ubshi0U8RU4y76KtjRQsea\nUHvgWJ6gwai7CAAAIIZlZqwDvMGll1QFLgA0gq8lzz2l3nMvqfJg5j81F3rZfSShtreKWIXlTMXs\nTaQqQi+7jySU8VZhpiOHLSEb16S7T8KXiJmflPcal307utleKtFmq8MsQvZsG7pTRKmi7JLh9iYu\nDdI61uX81dah/DqOTQqimd1HEsp4q5grj3zSNaGHNd/azGTACRVx3ues2YLBKMme7oOaNPkaY2/D\nV5meU0XZpRDamwRKYmuXL5ffu8QOrufu9iIfQx+7jzRU8FaxeO0dbEJZRbDFw8x45NRGnPeFbccl\nEeH1atFw8WjwMkp1zi4Frgh2nnt1Pl8aysVgCwSdOtl9pKCEt4ob+oyY4dCCqcAmYaWdzAbO20vX\n760cloaJ5GDzFennVOfsUi95yuz6/PBSIcU/hU52H8mo4a1yrucCijAP1unter/a2beKyHfJiyLu\nLi+KPnYfiSjirWJFpodRldz0GtE+3pqfK5vvjt6GrleXXmObeclrRdxeflc4qmRoZveRhDLeKiaP\n5qxz7ds7N4wirxSxOtxElkH1e5876n2ywYLt8MgyShVll4J7Ms8LvZCuL29zO73Pgp9AM7uPBNTx\nVuGWl5+XNoK9N/fmjSLYvqvRp8WMMdeOGY1Cg5o9p4qyS+PARp8kVMTN5dnoc2cXNptohN2HLt4q\nV2x3dZcAKANzyGXmqACcsE8GygDowbTfx/QjcGYyWPzzz2Kg4Xy6O1RYE90AxiN/1iOkN9P/2WrG\noHwebWdpLBbG8utrPTFGdZdAd760C6f2QlegiFJM+/OY2CGIKebaRplQRAmSaz5eKVoARQiT0Tvo\nGmX2oAgh8kSQekaZUIQIue9/DaNMKKIwxWIE7aLMnl93CTSjeA1rFmVCEYUQ7AV0ijKhiPyUiRT1\niTKhiLyUvs81iTKXUEQuljJiARaDqN9OQBH5kDNemPbr/hwAAAAAeCyyfDaa4ToC5PlsNMN1BEjz\n2WiI6wgIkeWz0QzXESDLZ6MhriNAms9GM1xHgEyfjQa4jgBZPhsNcR0B0nw29HAdcYU8dtqFNJ8N\nLVxHbJm+rwAAAACQB1xHGoMUn42Pmf/PP/7so+4PAyRQ3mdjaYRTNdlkSz1mZYM0Sq6Yno5H/mg8\njXkDNKWUIr42C1GDAbRFXBHT/jw2dAiCCr0WhIIbRBWRVu9JWgE6ILSGPrtvQJipLcUVkTN+RJip\nKUUVUeTmR5ipI4VWTBcPEBBmakcBRYjVLsJMzciriDI9AMJMnciliNJRIsJMfcjhsyHnFkeYqQtZ\niljKCwNYIILOQ39kDhVgPAIAAPpiiT83GjSSVeYKBlmWI1XkBuSTvaZFluVIFbkB+bgmsQ6U8kfZ\nxyPJcqSS3EAlODYhe5r6bHpZ
liNV5Aaks/vMOkOS5UgluQH57HfOapt2gizLkSpyA1XgbW3nkHxY\nouWI9NxAVXg08XcJSZYjleQGKmEXRJZbuk44KstypIrcQDWw0echMZCQZjlSQW4AAAAAaBnLfh+T\nqHRHiuUIZ9kf+v/+tz+EKPRGkiLYbGyuBaYLiEJnZChiOp75V9O5IQqtKa2IUA538/IhCn0pp4hY\nOYRAFJpSRhEfbO1PyqodiEJHhBURyGExyHx6NEShHWKK4HLIuRQMotALAUX0CsghBKLQiKKKmLAV\nvwILRSEKXSikiEAO/kzYFQCi0IL8iignhxCIQn1yKmIpQQ6nnCAKtcmjCFaJMi1kIAqVyVbEsorq\n46KAJjSlort5CfsRAADQGcusuwRALdwYv4lORSYhT1VlDCQSo4hOVSYh3+A+oggb16S7veVQxw0V\nwP5YG0o3XvByH+y/MZt47byRpyqWaXU6L9VkDApCD2vySc01WTlnRXjOyvM2JnF3gSw2zn2SykxC\nvqONUAC6P/0JlBApgi8VXrM2Yx3TdTz9PlbiCRF0SN9e6v46AOHOAezPlSIiB7NzL3LN23s1giCv\nz7AfUYGiinjtdn9VVZa3Y7furwN8VcSe9RpBNLm24xTxdnyvpmV/P8LXTg2uFbGmn2R9YJHlJogs\nD3GKeD++M5MQ+V5Cz3z0+aPurwPcKIJ87qjpmlejT3KviJNJiPyOA/64AAAAAChB72fmijCgJEWe\nIZub6Xjo/+kP8fw/HfHlT6pbGou5MSVTY77Ag0L1w5fduH+M/NFHzGugCXIV8aVdOLUXQB9kKqI3\niIkdWEyR7TsAlEGaIlJqPlYpQFFGchSR0TsgytQHKYrIE0EiytSE8orIff8jytSCUUl/w0IxAqJM\nDSilCIEaRpSpOiUUIdgLIMpUG2FFlIkUEWUqzEBIEaXvc0SZyiLkeGlIiAVYDCLtWQ+gZnpyxguS\nsgEAAADA4/GEXGekuYXAdUQ5tlQklSy3kMrsTIAoNqXUJZ8OpRurQDJZbiGV2ZkAYfhSQLon3mbn\nFUspyy2kMjsTIARTxGETvLDotkg6aW4hVdmZAEGYIqjNXoV/8yLLLaQyOxMgiKgiJLmFVGlnAoQQ\n6zVkuYVUZmcChHGpR/ZBZLk2nQKpZLmFVGZnAoTxDtHoc10glSy3kMrsTAAAAADQBqb/wby6hiDH\nhmQ6/O8QkmgGUhSxHA7/N8Tz3ZqBDEVMFoPg72AxqfvDAAlIUEQoCEiiIZRXRCQISKIZlFbEx+Ky\nVMNYYHmX9pRVxNgfJ74DOlJSEfcSgCS0p5wivgoAktCdUoqIixuu4wqgIWUUET+2uIw9gI6UUETS\nYBOS0BpxRST/+gBJ6IyoIqZp/4+xHOL/vbRFVBGj1DqfDkd1fzDwYCbpjcAUv2cDAEBrMYusOAct\ngD+O9oosBxE4jDSdO0VkOYjAYURJtuY+qMitQ3crjxDboc6e8O3OJsQ1XYeaa5PutkGfsDWpY9mU\nLyKOEphucPCT+RnR3f5OEVkOInAYURKXblzP3blkfViRz2DrBvVqB//cQBJuUPtbtiRw5QQtgMm2\nq2DHmkQJCD2syWewY7WzvNV9r0GyHUTgMKIcLg3CwU1Qt8ykitUxcdfhNlABWzlM6P5kM7Anpx3u\nOQHfyXYESiHrr4rIchCBw4h6sLomJuW4kasE39o0PMjqmSvitGUvogRh6MC2NvkaWWY6iMBhREFC\nRVwrgeRRRGRIkqqILAcROIyoCK/jzSF8w91G9m64XR2SFRElOCsirtfIchCBw4iS8Dp26SfxVpsg\nRmSR5TaMLKmdrIgowVkRq93as+8UkeUgAocRJeF1zAaT3GWEjT6DsSTf2iRZEecEkSL46NO8tTLK\nchCBwwgAAAAAAMhP1mw8OU4lQB8yFQF/kZaRqQg8yq1lQBHgFigC3AJFgFuyFDGCIloGFAFugS
LA\nLZmKgK9Iy4AiwC1QBLgFigC3ZCliAEW0jCxFGFAEAAAAAL5gmcnHTGrylWFwIGkNZrQQJBaXWh5x\nPRKzlhw0E4tmKCLcQhFtwaWU2i7dO9QJ+gVvRelhe3eY9xpsueC1ZwloLKwRcHcbj2ycoAdZedx0\n5PZwpIhrzxLQWLgimAiCFxYXw8G+O3xWxJVnCWgsLj0vOud9BO8mbg+fFUGu1hiDpnKriNjDUESr\nuFaERbdxh6GIVsGig0gRxDwwbzvr5jCBIlrGgY0+SVjPbPTJbUrOQBEAAAAAAAAAAABQnf8DmwKy\nJJha+OIAAAAldEVYdGRhdGU6Y3JlYXRlADIwMTctMTEtMDJUMjI6NTM6MzkrMDk6MDBXyJTvAAAA\nJXRFWHRkYXRlOm1vZGlmeQAyMDE3LTExLTAyVDIyOjUzOjM5KzA5OjAwJpUsUwAAACN0RVh0cHM6\nSGlSZXNCb3VuZGluZ0JveAA1Mjl4NDA1LTI2NC0yMDKzNU7+AAAAHHRFWHRwczpMZXZlbABBZG9i\nZS0zLjAgRVBTRi0zLjAKm3C74wAAACJ0RVh0cHM6U3BvdENvbG9yLTAAZm9udCBMaWJlcmF0aW9u\nU2Fuc/4Zp8YAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Parse the bracketed sample string into an NLTK tree and render it inline.\n",
    "draw_nltk_tree(nltkTree.fromstring(sample))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tree Class "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Code borrowed from https://github.com/bogatyy/cs224d/tree/master/assignment3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Node:\n",
    "    \"\"\"One node of a binary labelled tree.\n",
    "\n",
    "    ``label`` and ``word`` are stored as given. ``parent``, ``left`` and\n",
    "    ``right`` start as ``None`` and ``isLeaf`` starts as ``False``; whoever\n",
    "    builds the tree fills them in afterwards.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, label, word=None):\n",
    "        self.label = label\n",
    "        self.word = word\n",
    "        # Links to the surrounding nodes, unset until the tree is assembled.\n",
    "        self.parent = self.left = self.right = None\n",
    "        # Set explicitly by the tree builder rather than derived from ``word``.\n",
    "        self.isLeaf = False\n",
    "\n",
    "    def __str__(self):\n",
    "        \"\"\"Return ``[word:label]`` for a leaf, otherwise the node between its two subtrees.\"\"\"\n",
    "        if not self.isLeaf:\n",
    "            return '({0} <- [{1}:{2}] -> {3})'.format(self.left, self.word, self.label, self.right)\n",
    "        return '[{0}:{1}]'.format(self.word, self.label)\n",
    "\n",
    "\n",
    "class Tree:\n",
    "    \"\"\"Binary labelled tree parsed from one bracketed line such as ``(3 (2 a) (4 b))``.\n",
    "\n",
    "    Attributes:\n",
    "        open, close: the single characters that start and end a node.\n",
    "        root: the root ``Node`` of the parsed tree.\n",
    "        labels: every node's label in post-order (children first, root last).\n",
    "        num_words: ``len(labels)``; note that this counts all nodes, not only leaves.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, treeString, openChar='(', closeChar=')'):\n",
    "        \"\"\"Parse ``treeString`` into a tree of ``Node`` objects.\n",
    "\n",
    "        Args:\n",
    "            treeString: bracketed tree text. All whitespace is discarded and\n",
    "                the rest is processed character by character.\n",
    "            openChar: single character that opens a node.\n",
    "            closeChar: single character that closes a node.\n",
    "        \"\"\"\n",
    "        # Use the delimiters the caller asked for instead of hard-coded brackets.\n",
    "        self.open = openChar\n",
    "        self.close = closeChar\n",
    "        tokens = [char for chunk in treeString.strip().split() for char in chunk]\n",
    "        self.root = self.parse(tokens)\n",
    "        # get list of labels as obtained through a post-order traversal\n",
    "        self.labels = get_labels(self.root)\n",
    "        self.num_words = len(self.labels)\n",
    "\n",
    "    def parse(self, tokens, parent=None):\n",
    "        \"\"\"Recursively build the subtree described by ``tokens``.\n",
    "\n",
    "        Args:\n",
    "            tokens: list of single characters spanning exactly one node, from\n",
    "                its opening delimiter to its closing delimiter.\n",
    "            parent: node recorded as the parent of the node built here.\n",
    "\n",
    "        Returns:\n",
    "            The ``Node`` for this span. A span without a nested node becomes a\n",
    "            leaf whose word is the lower-cased text after the label; otherwise\n",
    "            the span is split into a left and a right child.\n",
    "\n",
    "        Raises:\n",
    "            AssertionError: if the span does not start and end with the delimiters.\n",
    "        \"\"\"\n",
    "        assert tokens[0] == self.open, \"Malformed tree\"\n",
    "        assert tokens[-1] == self.close, \"Malformed tree\"\n",
    "\n",
    "        split = 2  # position after open and label\n",
    "        countOpen = countClose = 0\n",
    "\n",
    "        if tokens[split] == self.open:\n",
    "            countOpen += 1\n",
    "            split += 1\n",
    "        # Advance until the left child's delimiters balance; ``split`` then\n",
    "        # points at the first token of the right child.\n",
    "        while countOpen != countClose:\n",
    "            if tokens[split] == self.open:\n",
    "                countOpen += 1\n",
    "            if tokens[split] == self.close:\n",
    "                countClose += 1\n",
    "            split += 1\n",
    "\n",
    "        # The label is the single character right after the opening delimiter.\n",
    "        node = Node(int(tokens[1]))\n",
    "        node.parent = parent\n",
    "\n",
    "        # No nested node was seen, so this span is a leaf holding a word.\n",
    "        if countOpen == 0:\n",
    "            node.word = ''.join(tokens[2: -1]).lower()\n",
    "            node.isLeaf = True\n",
    "            return node\n",
    "\n",
    "        node.left = self.parse(tokens[2: split], parent=node)\n",
    "        node.right = self.parse(tokens[split: -1], parent=node)\n",
    "\n",
    "        return node\n",
    "\n",
    "    def get_words(self):\n",
    "        \"\"\"Return the words of the tree's leaves, left to right.\"\"\"\n",
    "        return [node.word for node in getLeaves(self.root)]\n",
    "\n",
    "def get_labels(node):\n",
    "    \"\"\"Collect the labels of ``node``'s subtree in post-order (left, right, then node).\"\"\"\n",
    "    if node is None:\n",
    "        return []\n",
    "    collected = get_labels(node.left)\n",
    "    collected.extend(get_labels(node.right))\n",
    "    collected.append(node.label)\n",
    "    return collected\n",
    "\n",
    "def getLeaves(node):\n",
    "    \"\"\"Return the leaf nodes under ``node`` from left to right (empty list for ``None``).\"\"\"\n",
    "    found = []\n",
    "    pending = [node]\n",
    "    while pending:\n",
    "        current = pending.pop()\n",
    "        if current is None:\n",
    "            continue\n",
    "        if current.isLeaf:\n",
    "            found.append(current)\n",
    "        else:\n",
    "            # Push right first so the left subtree is popped, and listed, first.\n",
    "            pending.append(current.right)\n",
    "            pending.append(current.left)\n",
    "    return found\n",
    "\n",
    "    \n",
    "def loadTrees(dataSet='train'):\n",
    "    \"\"\"\n",
    "    Load and parse the trees of one dataset split.\n",
    "\n",
    "    Reads ``../dataset/trees/<dataSet>.txt`` and turns every non-blank line\n",
    "    into a ``Tree``. Blank lines are skipped because an empty string cannot\n",
    "    be parsed as a tree.\n",
    "\n",
    "    Args:\n",
    "        dataSet: split name used to build the file name.\n",
    "\n",
    "    Returns:\n",
    "        List of ``Tree`` objects in file order.\n",
    "    \"\"\"\n",
    "    file = '../dataset/trees/%s.txt' % dataSet\n",
    "    print(\"Loading %s trees..\" % dataSet)\n",
    "    with open(file, 'r', encoding='utf-8') as fid:\n",
    "        trees = [Tree(line) for line in fid if line.strip()]\n",
    "\n",
    "    return trees"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading train trees..\n"
     ]
    }
   ],
   "source": [
    "# Parse every training tree once; the vocabulary and training cells below reuse this list.\n",
    "train_data = loadTrees('train')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build Vocab "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Sort the unique words: plain set iteration order can differ between runs,\n",
    "# which would give the same word a different index each time.\n",
    "vocab = sorted(set(flatten([t.get_words() for t in train_data])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Index 0 is reserved for unknown words; every vocabulary word then gets the next free index.\n",
    "word2index = {'<UNK>': 0}\n",
    "for word in vocab:\n",
    "    # setdefault leaves an existing entry untouched, so no word is renumbered.\n",
    "    word2index.setdefault(word, len(word2index))\n",
    "\n",
    "# Reverse lookup: index -> word.\n",
    "index2word = dict((index, word) for word, index in word2index.items())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Modeling "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"../images/09.rntn-layer.png\">\n",
    "<center>borrowed image from https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf</center>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class RNTN(nn.Module):\n",
    "    \"\"\"Recursive Neural Tensor Network over binary trees.\n",
    "\n",
    "    Leaves are embedded with a lookup table. Each internal node combines the\n",
    "    concatenated vectors of its two children ``x`` (1x2D) as\n",
    "    ``tanh(x V x^T + x W + b)``, where ``V`` holds one 2Dx2D slice per output\n",
    "    unit. A linear layer followed by ``log_softmax`` scores node vectors.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, word2index, hidden_size, output_size):\n",
    "        \"\"\"Create the parameters.\n",
    "\n",
    "        Args:\n",
    "            word2index: dict mapping a word to its embedding row; must contain ``'<UNK>'``.\n",
    "            hidden_size: dimension D of word and node vectors.\n",
    "            output_size: number of classes scored per node.\n",
    "        \"\"\"\n",
    "        super(RNTN, self).__init__()\n",
    "\n",
    "        self.word2index = word2index\n",
    "        self.embed = nn.Embedding(len(word2index), hidden_size)\n",
    "        # One 2Dx2D slice of the composition tensor per output unit.\n",
    "        self.V = nn.ParameterList([nn.Parameter(torch.randn(hidden_size * 2, hidden_size * 2)) for _ in range(hidden_size)])\n",
    "        self.W = nn.Parameter(torch.randn(hidden_size * 2, hidden_size))\n",
    "        self.b = nn.Parameter(torch.randn(1, hidden_size))\n",
    "        self.W_out = nn.Linear(hidden_size, output_size)\n",
    "\n",
    "    def init_weight(self):\n",
    "        \"\"\"Re-initialise the embedding, output weight, every V slice and W with Xavier-uniform; zero b.\"\"\"\n",
    "        nn.init.xavier_uniform(self.embed.state_dict()['weight'])\n",
    "        nn.init.xavier_uniform(self.W_out.state_dict()['weight'])\n",
    "        for param in self.V.parameters():\n",
    "            nn.init.xavier_uniform(param)\n",
    "        nn.init.xavier_uniform(self.W)\n",
    "        self.b.data.fill_(0)\n",
    "\n",
    "    def _propagate(self, node, table):\n",
    "        \"\"\"Compute the vector of ``node`` (children first), store it in ``table`` and return it.\"\"\"\n",
    "        if node.isLeaf:\n",
    "            # Words missing from the vocabulary share the '<UNK>' embedding.\n",
    "            index = self.word2index.get(node.word, self.word2index['<UNK>'])\n",
    "            current = self.embed(Variable(LongTensor([index])))  # 1xD\n",
    "        else:\n",
    "            left = self._propagate(node.left, table)\n",
    "            right = self._propagate(node.right, table)\n",
    "\n",
    "            concated = torch.cat([left, right], 1)  # 1x2D\n",
    "            # Each slice v contributes one scalar x v x^T (1x1).\n",
    "            xVx = [torch.matmul(torch.matmul(concated, v), concated.transpose(0, 1)) for v in self.V]\n",
    "            xVx = torch.cat(xVx, 1)  # 1xD\n",
    "            Wx = torch.matmul(concated, self.W)  # 1xD\n",
    "\n",
    "            current = F.tanh(xVx + Wx + self.b)  # 1xD\n",
    "        table[node] = current\n",
    "        return current\n",
    "\n",
    "    def tree_propagation(self, node):\n",
    "        \"\"\"Compute the vector of every node in the subtree rooted at ``node``.\n",
    "\n",
    "        Returns:\n",
    "            OrderedDict mapping each node to its 1xD vector, in post-order\n",
    "            (left subtree, right subtree, then the node itself).\n",
    "        \"\"\"\n",
    "        # A single table is filled by the whole recursion, so entries are never\n",
    "        # re-copied while results travel back up the tree.\n",
    "        recursive_tensor = OrderedDict()\n",
    "        self._propagate(node, recursive_tensor)\n",
    "        return recursive_tensor\n",
    "\n",
    "    def forward(self, Trees, root_only=False):\n",
    "        \"\"\"Score the nodes of one tree or a list of trees.\n",
    "\n",
    "        Args:\n",
    "            Trees: a tree object exposing ``root``, or a list of them.\n",
    "            root_only: if true, score only each tree's root; otherwise score\n",
    "                every node, tree by tree, in post-order.\n",
    "\n",
    "        Returns:\n",
    "            Log-probabilities with one row per scored node and one column per class.\n",
    "        \"\"\"\n",
    "        if not isinstance(Trees, list):\n",
    "            Trees = [Trees]\n",
    "\n",
    "        propagated = []\n",
    "        for tree in Trees:\n",
    "            node_vectors = self.tree_propagation(tree.root)\n",
    "            if root_only:\n",
    "                propagated.append(node_vectors[tree.root])\n",
    "            else:\n",
    "                propagated.extend(node_vectors.values())\n",
    "\n",
    "        propagated = torch.cat(propagated)  # (num_of_node in batch, D)\n",
    "\n",
    "        return F.log_softmax(self.W_out(propagated), 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training takes a while. The model builds its computational graph dynamically for each tree, so the computation is difficult to batch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Optimisation schedule\n",
    "EPOCH = 20\n",
    "BATCH_SIZE = 20\n",
    "LR = 0.01\n",
    "LAMBDA = 1e-5  # weight decay given to Adam once the learning rate is annealed\n",
    "RESCHEDULED = False  # set to True by the training loop after the one-off learning-rate drop\n",
    "\n",
    "# Model and objective\n",
    "HIDDEN_SIZE = 30\n",
    "ROOT_ONLY = False  # False: train on every node's label; True: only on each tree's root label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 5 output classes; assumes the treebank's node labels are the integers 0-4 -- TODO confirm.\n",
    "model = RNTN(word2index, HIDDEN_SIZE, 5)\n",
    "model.init_weight()\n",
    "if USE_CUDA:\n",
    "    model = model.cuda()\n",
    "\n",
    "# RNTN.forward already returns log-probabilities (log_softmax), so the matching\n",
    "# criterion is NLLLoss; CrossEntropyLoss would apply log_softmax a second time.\n",
    "loss_function = nn.NLLLoss()\n",
    "# NOTE(review): no weight_decay here, while the optimizer created when the\n",
    "# learning rate is annealed uses LAMBDA -- confirm that is intended.\n",
    "optimizer = optim.Adam(model.parameters(), lr=LR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0/20] mean_loss : 1.62\n",
      "[0/20] mean_loss : 1.25\n",
      "[0/20] mean_loss : 0.95\n",
      "[0/20] mean_loss : 0.90\n",
      "[0/20] mean_loss : 0.88\n",
      "[1/20] mean_loss : 0.88\n",
      "[1/20] mean_loss : 0.84\n",
      "[1/20] mean_loss : 0.83\n",
      "[1/20] mean_loss : 0.82\n",
      "[1/20] mean_loss : 0.82\n",
      "[2/20] mean_loss : 0.81\n",
      "[2/20] mean_loss : 0.79\n",
      "[2/20] mean_loss : 0.78\n",
      "[2/20] mean_loss : 0.76\n",
      "[2/20] mean_loss : 0.75\n",
      "[3/20] mean_loss : 0.68\n",
      "[3/20] mean_loss : 0.73\n",
      "[3/20] mean_loss : 0.74\n",
      "[3/20] mean_loss : 0.72\n",
      "[3/20] mean_loss : 0.72\n",
      "[4/20] mean_loss : 0.74\n",
      "[4/20] mean_loss : 0.69\n",
      "[4/20] mean_loss : 0.69\n",
      "[4/20] mean_loss : 0.68\n",
      "[4/20] mean_loss : 0.67\n",
      "[5/20] mean_loss : 0.73\n",
      "[5/20] mean_loss : 0.65\n",
      "[5/20] mean_loss : 0.64\n",
      "[5/20] mean_loss : 0.64\n",
      "[5/20] mean_loss : 0.65\n",
      "[6/20] mean_loss : 0.67\n",
      "[6/20] mean_loss : 0.62\n",
      "[6/20] mean_loss : 0.62\n",
      "[6/20] mean_loss : 0.62\n",
      "[6/20] mean_loss : 0.62\n",
      "[7/20] mean_loss : 0.57\n",
      "[7/20] mean_loss : 0.59\n",
      "[7/20] mean_loss : 0.59\n",
      "[7/20] mean_loss : 0.59\n",
      "[7/20] mean_loss : 0.59\n",
      "[8/20] mean_loss : 0.60\n",
      "[8/20] mean_loss : 0.58\n",
      "[8/20] mean_loss : 0.59\n",
      "[8/20] mean_loss : 0.60\n",
      "[8/20] mean_loss : 0.60\n",
      "[9/20] mean_loss : 0.52\n",
      "[9/20] mean_loss : 0.58\n",
      "[9/20] mean_loss : 0.60\n",
      "[9/20] mean_loss : 0.59\n",
      "[9/20] mean_loss : 0.59\n",
      "[10/20] mean_loss : 0.56\n",
      "[10/20] mean_loss : 0.56\n",
      "[10/20] mean_loss : 0.56\n",
      "[10/20] mean_loss : 0.56\n",
      "[10/20] mean_loss : 0.56\n",
      "[11/20] mean_loss : 0.52\n",
      "[11/20] mean_loss : 0.54\n",
      "[11/20] mean_loss : 0.54\n",
      "[11/20] mean_loss : 0.54\n",
      "[11/20] mean_loss : 0.55\n",
      "[12/20] mean_loss : 0.55\n",
      "[12/20] mean_loss : 0.53\n",
      "[12/20] mean_loss : 0.53\n",
      "[12/20] mean_loss : 0.53\n",
      "[12/20] mean_loss : 0.53\n",
      "[13/20] mean_loss : 0.59\n",
      "[13/20] mean_loss : 0.52\n",
      "[13/20] mean_loss : 0.52\n",
      "[13/20] mean_loss : 0.53\n",
      "[13/20] mean_loss : 0.53\n",
      "[14/20] mean_loss : 0.49\n",
      "[14/20] mean_loss : 0.51\n",
      "[14/20] mean_loss : 0.51\n",
      "[14/20] mean_loss : 0.52\n",
      "[14/20] mean_loss : 0.52\n",
      "[15/20] mean_loss : 0.43\n",
      "[15/20] mean_loss : 0.51\n",
      "[15/20] mean_loss : 0.51\n",
      "[15/20] mean_loss : 0.51\n",
      "[15/20] mean_loss : 0.51\n",
      "[16/20] mean_loss : 0.46\n",
      "[16/20] mean_loss : 0.50\n",
      "[16/20] mean_loss : 0.50\n",
      "[16/20] mean_loss : 0.50\n",
      "[16/20] mean_loss : 0.50\n",
      "[17/20] mean_loss : 0.50\n",
      "[17/20] mean_loss : 0.50\n",
      "[17/20] mean_loss : 0.50\n",
      "[17/20] mean_loss : 0.50\n",
      "[17/20] mean_loss : 0.51\n",
      "[18/20] mean_loss : 0.46\n",
      "[18/20] mean_loss : 0.50\n",
      "[18/20] mean_loss : 0.50\n",
      "[18/20] mean_loss : 0.49\n",
      "[18/20] mean_loss : 0.49\n",
      "[19/20] mean_loss : 0.49\n",
      "[19/20] mean_loss : 0.49\n",
      "[19/20] mean_loss : 0.49\n",
      "[19/20] mean_loss : 0.50\n",
      "[19/20] mean_loss : 0.50\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(EPOCH):\n",
    "    losses = []\n",
    "\n",
    "    # Learning rate annealing: a one-off drop to a tenth of LR at the halfway epoch.\n",
    "    # The annealed rate is computed on the fly so the LR constant itself is never\n",
    "    # changed and later cells that read LR still see the configured value.\n",
    "    if not RESCHEDULED and epoch == EPOCH // 2:\n",
    "        optimizer = optim.Adam(model.parameters(), lr=LR * 0.1, weight_decay=LAMBDA)  # L2 norm\n",
    "        RESCHEDULED = True\n",
    "\n",
    "    for i, batch in enumerate(getBatch(BATCH_SIZE, train_data)):\n",
    "\n",
    "        # Targets follow the order of the model's outputs: the root label only,\n",
    "        # or every node's label tree by tree.\n",
    "        if ROOT_ONLY:\n",
    "            labels = [tree.labels[-1] for tree in batch]\n",
    "        else:\n",
    "            labels = flatten([tree.labels for tree in batch])\n",
    "        labels = Variable(LongTensor(labels))\n",
    "\n",
    "        model.zero_grad()\n",
    "        preds = model(batch, ROOT_ONLY)\n",
    "\n",
    "        loss = loss_function(preds, labels)\n",
    "        # Summing then converting reads the scalar whether the loss is a\n",
    "        # one-element tensor or a zero-dimensional one.\n",
    "        losses.append(float(loss.data.sum()))\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Report the mean loss since the previous report, every 100 batches.\n",
    "        if i % 100 == 0:\n",
    "            print('[%d/%d] mean_loss : %.2f' % (epoch, EPOCH, np.mean(losses)))\n",
    "            losses = []"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The convergence of the model is unstable and depends on the initial values. I had to try 5~6 times to get this result."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading test trees..\n"
     ]
    }
   ],
   "source": [
    "# Parse the held-out test trees with the same loader used for training.\n",
    "test_data = loadTrees('test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Running totals for the evaluation cell: correct predictions and scored nodes.\n",
    "accuracy = num_node = 0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fine-grained all"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the paper, they achieved 80.7% accuracy on fine-grained labels for all phrases."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "79.33705899068254\n"
     ]
    }
   ],
   "source": [
    "# Start from zero in this cell so that re-running it does not add to the\n",
    "# totals left over from a previous evaluation.\n",
    "accuracy = 0\n",
    "num_node = 0\n",
    "\n",
    "for test in test_data:\n",
    "    model.zero_grad()\n",
    "    preds = model(test, ROOT_ONLY)\n",
    "    labels = test.labels[-1:] if ROOT_ONLY else test.labels\n",
    "    # max(1)[1] is the index of the highest-scoring class for each row.\n",
    "    predicted = preds.max(1)[1].data.tolist()\n",
    "    for pred, label in zip(predicted, labels):\n",
    "        num_node += 1\n",
    "        if pred == label:\n",
    "            accuracy += 1\n",
    "\n",
    "print(accuracy/num_node * 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TODO "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* https://github.com/nearai/pytorch-tools # Dynamic batch using TensorFold"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Further topics "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* <a href=\"https://arxiv.org/pdf/1503.00075.pdf\">Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks</a>\n",
    "* <a href=\"https://arxiv.org/abs/1603.06021\">A Fast Unified Model for Parsing and Sentence Understanding(SPINN)</a>\n",
    "* <a href=\"https://devblogs.nvidia.com/parallelforall/recursive-neural-networks-pytorch/?utm_campaign=Revue%20newsletter&utm_medium=Newsletter&utm_source=revue\">Posting about SPINN</a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
